/**
 * \file dnn/src/cuda/reduce_helper/largeBC.cuinl
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/cub/util_ptx.cuh"

#include <algorithm>
#include <cstdio>

namespace megdnn {
namespace cuda {
namespace reduce_intl {

/**
 * Launch configuration for one pass of the reduction over an (A, B, C)
 * problem.
 *
 * Values derived from the shape:
 *   - BX: smallest power of two that is >= C, clamped to [1, 32];
 *   - BY: 512 / BX, so (BY, BX) is the blockDim used to launch the reduce
 *     kernel;
 *   - factor: BY * 4, the divisor applied to B by one pass;
 *   - NA, NB, NC: A, ceil(B / factor) and ceil(C / BX);
 *   - nr_reduces: how many times B has to be divided by factor (rounding up)
 *     until it is <= 1; never less than 1.
 */
struct ExecPolicy {
    ExecPolicy(size_t A, size_t B, size_t C):
        A(A), B(B), C(C)
    {
        // BX follows C: keep doubling until it covers C or reaches 32
        for (BX = 1; BX < 32 && BX < C; BX *= 2) {
        }
        BY = 512 / BX;
        factor = BY * 4;

        NA = A;
        NB = DIVUP(B, factor);
        NC = DIVUP(C, BX);

        // count the passes required to shrink B down to a single entry
        nr_reduces = 0;
        for (size_t remaining = B; remaining > 1;
             remaining = DIVUP(remaining, factor)) {
            ++nr_reduces;
        }
        if (nr_reduces < 1) {
            // even a trivial B still needs one pass
            nr_reduces = 1;
        }
    }

    /** Policy for the following pass: same A and C, B divided by factor. */
    ExecPolicy next() const
    {
        const size_t reduced_B = DIVUP(B, factor);
        return ExecPolicy(A, reduced_B, C);
    }

    size_t factor;
    size_t nr_reduces;
    size_t BY, BX;
    size_t NA, NB, NC;
    size_t A, B, C;
};

/**
 * One pass of a reduction along axis B of data addressed as (A, B, C).
 *
 * Element (a, b, c) is read through \p rdr at index a*B*C + b*C + c. Each
 * thread folds up to four entries of B (spaced blockDim.y apart) with
 * opr.apply, starting from opr.INIT. The block then combines its rows in
 * shared memory, and the threads with threadIdx.y == 0 write one partial
 * result per column through \p wtr at index a*B2*C + b2*C + c, where b2 is
 * the block's position along B.
 *
 * Launch layout expected by the kernel:
 *   - blockDim = (BX, BY): the shared tile is indexed
 *     [threadIdx.y][threadIdx.x]; invoke_kernel launches with BY == 512 / BX;
 *   - blockIdx.x tiles C, blockIdx.y tiles B in chunks of 4 * blockDim.y
 *     entries, blockIdx.z selects a;
 *   - static shared memory: BY * BX elements of wtype.
 *
 * Whenever blockIdx is referenced, bidy_offset and bidz_offset should be
 * added. This lets host code cover the y/z range of the problem with several
 * launches, each one handling a window that starts at the given offsets.
 *
 * \note All element indices are computed in uint32_t arithmetic.
 *
 * \tparam sync_within_warp when true, a __syncthreads() is issued after every
 *      step of the final reduction loop, in addition to the warp-level sync.
 */
template <class Operator, class Reader, class Writer, typename wtype,
         uint32_t BX, uint32_t BY, bool sync_within_warp>
__global__ void kern_largeBC(
        Operator opr, Reader rdr, Writer wtr,
        uint32_t A, uint32_t B, uint32_t B2, uint32_t C,
        uint32_t bidy_offset, uint32_t bidz_offset)
{
    volatile __shared__ wtype shared[BY][BX];
    wtype s = opr.INIT;
    uint32_t c = threadIdx.x + blockIdx.x * blockDim.x;
    uint32_t a = blockIdx.z+bidz_offset;
    if (c < C) {
        // first B entry handled by this thread; the other three follow at a
        // stride of blockDim.y
        uint32_t base = threadIdx.y + (blockIdx.y+bidy_offset)*4*blockDim.y;
#pragma unroll
        for (uint32_t i = 0; i < 4; ++i) {
            uint32_t b = base + i*blockDim.y;
            if (b < B) {
                s = opr.apply(s, rdr.read(a*B*C + b*C + c));
            }
        }
    }
    // threads with c >= C still publish INIT so that every tile entry is
    // written before it is read below
    shared[threadIdx.y][threadIdx.x] = s;
    __syncthreads();

    // rows [0, warp_y) hold 32 threads in total
    const uint32_t warp_y = 32 / BX;
    // tree reduction over rows, halving the number of live rows each step;
    // BY is a template parameter, so steps that do not apply are compiled out
#pragma unroll
    for (uint32_t k = 256; k > warp_y; k >>= 1) {
        if (BY >= k<<1) {
            if (threadIdx.y < k) {
                shared[threadIdx.y][threadIdx.x] = opr.apply(
                        shared[threadIdx.y][threadIdx.x],
                        shared[threadIdx.y+k][threadIdx.x]);
            }
            __syncthreads();
        }
    }
    // Final steps. Only rows with threadIdx.y < k (hence < warp_y) do any
    // work, but the loop itself is executed by every thread of the block so
    // that the optional __syncthreads() below is never reached from
    // divergent control flow.
#pragma unroll
    for (uint32_t k = warp_y; k > 0; k >>= 1) {
        if (threadIdx.y < k) {
            shared[threadIdx.y][threadIdx.x] =
                    opr.apply(shared[threadIdx.y][threadIdx.x],
                              shared[threadIdx.y + k][threadIdx.x]);
        }
        if (sync_within_warp) {
            __syncthreads();
        }
        /**
         * \warning Since CUDA 9.0, for Volta and Turing architecture,
         * applications that assume reads and writes are implicitly visible
         * to other threads in same warp need to insert the new __syncwarp()
         * warp-wide barrier synchronization instruction between steps where
         * data is exchanged between threads via global or shared memory.
         * For details, please refer to
         * https://docs.nvidia.com/cuda/volta-tuning-guide/index.html
         */
        cub::WARP_SYNC(0xffffffff);
    }
    // row 0 now holds the block's result for each column
    if (threadIdx.y == 0 && c < C) {
        uint32_t b2 = blockIdx.y+bidy_offset;
        wtr.write(a*B2*C + b2*C + c, shared[0][threadIdx.x]);
    }
}

/**
 * Launch kern_largeBC instantiated for a (nBX, 512 / nBX) thread block, but
 * only when \p p asks for exactly that block shape; otherwise do nothing.
 *
 * The B and A extents of the grid (p.NB and p.NA) are covered in windows of
 * at most 32768 thread blocks per launch; the origin of each window is
 * handed to the kernel as bidy_offset / bidz_offset.
 */
template <uint32_t nBX, class Operator, class Reader, class Writer,
         typename wtype, bool sync_within_warp>
void launch_largeBC_if_match(const ExecPolicy &p,
        const Operator &opr,
        const Reader &rdr,
        const Writer &wtr,
        cudaStream_t stream)
{
    if (p.BX != nBX || p.BY != 512 / nBX) {
        return;
    }
    // 32768 thread blocks for each call
    const size_t max_blocks = 32768;
    // extent of the reduced axis in the output of this pass
    const size_t B2 = DIVUP(p.B, p.factor);
    for (size_t bidy_offset = 0; bidy_offset < p.NB;
         bidy_offset += max_blocks) {
        for (size_t bidz_offset = 0; bidz_offset < p.NA;
             bidz_offset += max_blocks) {
            dim3 blocks;
            blocks.x = p.NC;
            blocks.y = std::min<size_t>(max_blocks, p.NB - bidy_offset);
            blocks.z = std::min<size_t>(max_blocks, p.NA - bidz_offset);
            // the kernel works on 32-bit sizes and offsets
            kern_largeBC<Operator, Reader, Writer, wtype, nBX, 512 / nBX,
                sync_within_warp><<<blocks, dim3(p.BX, p.BY), 0, stream>>>(
                        opr, rdr, wtr,
                        static_cast<uint32_t>(p.A),
                        static_cast<uint32_t>(p.B),
                        static_cast<uint32_t>(B2),
                        static_cast<uint32_t>(p.C),
                        static_cast<uint32_t>(bidy_offset),
                        static_cast<uint32_t>(bidz_offset));
        }
    }
}

/**
 * Run one reduction pass described by \p p on \p stream.
 *
 * Selects the kern_largeBC instantiation whose compile-time block shape
 * matches (p.BX, p.BY) among BX = 1, 2, 4, 8, 16, 32 with BY = 512 / BX,
 * launches it over the whole grid and then calls after_kernel_launch().
 *
 * \tparam Operator must have method wtype apply(wtype, wtype)
 * \tparam Operator must have const member INIT
 * \tparam Reader must have method wtype read(size_t idx)
 * \tparam Writer must have method void write(size_t idx, wtype)
 */
template <class Operator, class Reader, class Writer, typename wtype,
         bool sync_within_warp>
void invoke_kernel(const ExecPolicy &p,
        const Operator &opr,
        const Reader &rdr,
        const Writer &wtr,
        cudaStream_t stream)
{
    // each call is a no-op unless its block shape matches the policy
    launch_largeBC_if_match<1, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    launch_largeBC_if_match<2, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    launch_largeBC_if_match<4, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    launch_largeBC_if_match<8, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    launch_largeBC_if_match<16, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    launch_largeBC_if_match<32, Operator, Reader, Writer, wtype,
        sync_within_warp>(p, opr, rdr, wtr, stream);
    after_kernel_launch();
}

/**
 * Reader backed by a PublicOperator.
 *
 * Keeps its own copy of the operator and forwards read(idx) to it.
 */
template <class PublicOperator>
struct PublicReader {
    typedef typename PublicOperator::wtype wtype;

    PublicOperator opr;

    PublicReader(const PublicOperator &source_opr): opr(source_opr)
    {}

    __device__ wtype read(uint32_t idx)
    {
        return opr.read(idx);
    }
};

/**
 * Reader that loads element idx directly from a workspace buffer.
 */
template <typename wtype>
struct WorkspaceReader {
    wtype *workspace;

    WorkspaceReader(wtype *buffer): workspace(buffer)
    {}

    __device__ wtype read(uint32_t idx)
    {
        return workspace[idx];
    }
};

/**
 * Writer backed by a PublicOperator.
 *
 * Keeps its own copy of the operator and forwards write(idx, value) to it.
 */
template <class PublicOperator>
struct PublicWriter {
    typedef typename PublicOperator::wtype wtype;

    PublicOperator opr;

    PublicWriter(const PublicOperator &target_opr): opr(target_opr)
    {}

    __device__ void write(uint32_t idx, wtype value)
    {
        opr.write(idx, value);
    }
};

/**
 * Writer that stores a value at element idx of a workspace buffer.
 */
template <typename wtype>
struct WorkspaceWriter {
    wtype *workspace;

    WorkspaceWriter(wtype *buffer): workspace(buffer)
    {}

    __device__ void write(uint32_t idx, wtype value)
    {
        workspace[idx] = value;
    }
};

/**
 * Reduce axis B of an (A, B, C) problem, chaining as many kernel passes as
 * ExecPolicy(A, B, C).nr_reduces asks for.
 *
 * With a single pass the operator's own read/write are used on both ends.
 * With several passes the first one reads through the operator and writes to
 * \p workspace, intermediate passes alternate between two regions of
 * \p workspace, and the last pass writes through the operator again.
 * The second region starts A * ceil(B / factor) * C elements into
 * \p workspace.
 *
 * \tparam PublicOperator
 *      must have typedef for wtype
 *      must have const static member wtype INIT
 *      must have method wtype read(uint32_t idx)
 *      must have method wtype apply(const wtype &, const wtype &)
 *      must have method void write(uint32_t idx, const wtype &)
 */
template <class PublicOperator, bool sync_within_warp>
void run_largeBC(typename PublicOperator::wtype *workspace,
        size_t A, size_t B, size_t C,
        cudaStream_t stream, const PublicOperator &opr)
{
    typedef typename PublicOperator::wtype wtype;
    typedef PublicReader<PublicOperator> SrcReader;
    typedef PublicWriter<PublicOperator> DstWriter;
    typedef WorkspaceReader<wtype> TmpReader;
    typedef WorkspaceWriter<wtype> TmpWriter;
    using namespace reduce_intl;

    ExecPolicy policy(A, B, C);
    const size_t total_passes = policy.nr_reduces;

    if (total_passes == 1) {
        // src -> dst in one go, no workspace involved
        SrcReader rdr(opr);
        DstWriter wtr(opr);
        invoke_kernel<PublicOperator, SrcReader, DstWriter, wtype,
            sync_within_warp>(policy, opr, rdr, wtr, stream);
        return;
    }

    // two regions of the workspace used alternately as input and output
    const size_t B2 = DIVUP(B, policy.factor);
    wtype *input_buf = workspace;
    wtype *output_buf = workspace + A * B2 * C;

    // first pass: src -> workspace
    {
        SrcReader rdr(opr);
        TmpWriter wtr(input_buf);
        invoke_kernel<PublicOperator, SrcReader, TmpWriter, wtype,
            sync_within_warp>(policy, opr, rdr, wtr, stream);
    }
    policy = policy.next();

    // intermediate passes: workspace -> workspace (none when there are
    // exactly two passes)
    for (size_t pass = 1; pass + 1 < total_passes; ++pass) {
        TmpReader rdr(input_buf);
        TmpWriter wtr(output_buf);
        invoke_kernel<PublicOperator, TmpReader, TmpWriter, wtype,
            sync_within_warp>(policy, opr, rdr, wtr, stream);
        std::swap(input_buf, output_buf);
        policy = policy.next();
    }

    // last pass: workspace -> dst
    {
        TmpReader rdr(input_buf);
        DstWriter wtr(opr);
        invoke_kernel<PublicOperator, TmpReader, DstWriter, wtype,
            sync_within_warp>(policy, opr, rdr, wtr, stream);
    }
}

/**
 * Number of workspace bytes needed to reduce an (A, B, C) problem whose
 * intermediate values have type wtype.
 *
 * Nothing is needed for a single pass. Two passes need one buffer of
 * A * B2 * C elements, with B2 = ceil(B / factor). More passes need an
 * additional buffer of A * B3 * C elements, with B3 = ceil(B2 / factor).
 */
template <typename wtype>
size_t get_workspace_largeBC(size_t A, size_t B, size_t C)
{
    using namespace reduce_intl;
    const ExecPolicy policy(A, B, C);

    if (policy.nr_reduces == 1) {
        // direct reduce
        return 0;
    }

    const size_t B2 = DIVUP(B, policy.factor);
    const size_t first_bytes = sizeof(wtype) * A * B2 * C;
    if (policy.nr_reduces == 2) {
        // src->workspace->dst
        return first_bytes;
    }

    // src->workspace1->workspace2->dst
    const size_t B3 = DIVUP(B2, policy.factor);
    const size_t second_bytes = sizeof(wtype) * A * B3 * C;
    return first_bytes + second_bytes;
}


} // namespace reduce_intl
} // namespace cuda
} // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen
